#!/usr/bin/python2

from network import network
from random import randint

import pprint
import img_orient as i_o
import glob
import os

# Classification labels. A label's position in this list is the index of the
# network output neuron that stands for it.
directories = ['rock', 'paper', 'scissors', 'green']

# Shared network instance used by both rps_classify and rps_learn.
net = network(18,   # inputs
              1,    # hidden layers
              11,   # neurons per hidden layer
              4)    # outputs (one per entry in `directories`)

#returns either 'rock', 'paper' or 'scissors' as strings
def rps_classify():
    """Classify the current 'nao' image with the trained network.

    Loads the weights from "rps_weigths.txt" (the file rps_learn writes),
    converts the image returned by ``i_o.get_converted_img('nao', 0)`` and
    runs it through the network.

    Returns:
        The label from ``directories`` ('rock', 'paper', 'scissors' or
        'green') whose output neuron is strongest, provided at least one
        output is above 0.5; otherwise the string 'nothing'.
    """
    # Weights are (re)loaded from disk on every call.
    net.loadWeights("rps_weigths.txt")

    # NOTE(review): 'nao' / 0 are passed straight to img_orient; presumably
    # this grabs a live camera frame rather than reading a file - confirm.
    image = i_o.get_converted_img('nao', 0)
    net.calcOuts(image)

    # Presumably index 2 is the output layer (input, 1 hidden, output),
    # holding one activation per label in `directories`.
    outputs = net.outs[2]
    print(outputs)

    # Only trust a prediction when some neuron is confidently active; the
    # answer is then the strongest neuron overall.
    if any(value > 0.5 for value in outputs):
        return directories[outputs.index(max(outputs))]
    return 'nothing'


def _pick_random_image(label):
    """Return the path of a random regular file directly inside pics/<label>.

    Raises:
        OSError: if the folder pics/<label> does not exist.
        ValueError: if the folder contains no files.
    """
    folder = os.path.join('pics', label)
    # Only top-level regular files: the path is built from this folder, so
    # names from sub-directories would produce non-existent paths.
    names = [name for name in os.listdir(folder)
             if os.path.isfile(os.path.join(folder, name))]
    print(len(names))
    if not names:
        raise ValueError('no training images found in %s' % folder)
    return os.path.join(folder, names[randint(0, len(names) - 1)])


def _build_training_set(num_samples):
    """Randomly draw `num_samples` labelled images from pics/<label>/.

    Returns:
        (truth_in, truth_out, chosen): dicts keyed 0..num_samples-1 holding
        the converted images and their one-hot target vectors, plus a list
        counting how many samples were drawn for each label.
    """
    label_count = len(directories)
    truth_in = {}
    truth_out = {}
    chosen = [0] * label_count

    for z in range(num_samples):
        idx = randint(0, label_count - 1)
        image_path = _pick_random_image(directories[idx])
        truth_in[z] = i_o.get_converted_img(image_path, 1)
        # One-hot target: 1 on the chosen label's output neuron, 0 elsewhere.
        truth_out[z] = [1 if k == idx else 0 for k in range(label_count)]
        chosen[idx] += 1

    return truth_in, truth_out, chosen


def rps_learn(num_samples=75):
    """Sample a random training set, confirm it with the user, then train.

    Draws `num_samples` images (default 75) uniformly over `directories`.
    Prints how many were drawn per label and asks for confirmation;
    answering 'n' or 'N' aborts. Otherwise it reinitialises the network
    weights, trains, saves the weights to "rps_weigths.txt" and prints the
    network outputs for the first sample.

    Raises:
        ValueError: if num_samples < 1 or a label folder has no files.
        OSError: if a pics/<label> folder is missing.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be at least 1')

    truth_in, truth_out, chosen = _build_training_set(num_samples)
    print('done sampling %d images' % num_samples)

    print('chosen images:')
    print(chosen)

    answer = raw_input('continue learning with this test set? Y or N')
    # The prompt advertises "N", so accept either case.
    if answer.strip().lower() == 'n':
        print('exiting program')
        return

    net.initWeights()

    net.debug = False
    net.alpha = 1            # learning rate
    net.adaptive_alpha = True
    net.alpha_roof = 1

    net.useGraph()
    net.graph = True
    net.graphFreq = 1

    # NOTE(review): per the original inline comment the trailing arguments
    # are presumably mode=1 and target_sse=3 - confirm against network.train.
    net.train(truth_in, truth_out, 1, 3)

    net.saveWeights("rps_weigths.txt")
    print('done')

    # Sanity check: show the network's outputs for the first training sample.
    net.calcOuts(truth_in[0])
    print(net.outs[2])


# Running this file directly starts an interactive training session.
if "__main__" == __name__:
    rps_learn()